{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Config"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import nlpaug.augmenter.char as nac\n",
    "import nlpaug.augmenter.word as naw\n",
    "import nlpaug.flow as naf\n",
    "\n",
    "from nlpaug.util import Action"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Token:['The', 'quick', 'brown', 'fox', 'jumps', 'over', 'the', 'lazy', 'dog']\n"
     ]
    }
   ],
   "source": [
    "text = 'The quick brown fox jumps over the lazy dog'\n",
    "tokens = text.split(' ')\n",
    "print('Token:{}'.format(tokens))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Character Augmenter"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "from nlpaug.augmenter.char import CharAugmenter"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The --> THe\n"
     ]
    }
   ],
   "source": [
    "class CustomCharAug(CharAugmenter):\n",
    "    def __init__(self, name='CustChar_Aug', aug_min=1, aug_p=0.3, tokenizer=None):\n",
    "        super(CustomCharAug, self).__init__(\n",
    "            action=Action.SUBSTITUTE, name=name, aug_p=aug_p, aug_min=aug_min, tokenizer=tokenizer)\n",
    "\n",
    "        self.model = self.get_model()\n",
    "\n",
    "    def substitute(self, tokens):\n",
    "        results = []\n",
    "        for token in tokens:\n",
    "            result = ''\n",
    "            chars = self.tokenizer(token)\n",
    "            aug_cnt = self.generate_aug_cnt(len(chars))\n",
    "            char_idxes = [i for i, char in enumerate(chars)]\n",
    "            aug_idexes = self.sample(char_idxes, aug_cnt)\n",
    "\n",
    "            for i, char in enumerate(chars):\n",
    "                # Skip if no augment for char\n",
    "                if i not in aug_idexes:\n",
    "                    result += char\n",
    "                    continue\n",
    "\n",
    "                # Skip if no mapping\n",
    "                if char not in self.model or len(self.model[char]) < 1:\n",
    "                    result += char\n",
    "                    continue\n",
    "\n",
    "                result += self.sample(self.model[char], 1)[0]\n",
    "\n",
    "            results.append(result)\n",
    "\n",
    "        return results\n",
    "\n",
    "    def get_model(self):\n",
    "        result = {\n",
    "            'T': 't',\n",
    "            'h': 'H',\n",
    "            'e': 'E'\n",
    "        }\n",
    "        return result\n",
    "\n",
    "aug = CustomCharAug()\n",
    "\n",
    "for token in tokens:\n",
    "    print('{} --> {}'.format(token, aug.augment([token])[0]))\n",
    "    break"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Word Augmenter"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "from nlpaug.augmenter.word import WordAugmenter"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The --> Custom1\n",
      "quick --> Custom1\n",
      "brown --> Custom1\n",
      "fox --> Custom2\n",
      "jumps --> Custom1\n",
      "over --> Custom2\n",
      "the --> Custom2\n",
      "lazy --> Custom1\n",
      "dog --> Custom1\n"
     ]
    }
   ],
   "source": [
    "class CustomWordAug(WordAugmenter):\n",
    "    def __init__(self, name='CustomWord_Aug', aug_min=1, aug_p=0.3, tokenizer=None):\n",
    "        super(CustomWordAug, self).__init__(\n",
    "            action=Action.INSERT, name=name, aug_p=aug_p, aug_min=aug_min, tokenizer=tokenizer)\n",
    "        \n",
    "        self.model = self.get_model()\n",
    "\n",
    "    def insert(self, tokens):\n",
    "        \"\"\"\n",
    "        :param tokens: list of token\n",
    "        :return: list of token\n",
    "        \"\"\"\n",
    "        results = tokens.copy()\n",
    "\n",
    "        aug_cnt = self.generate_aug_cnt(len(tokens))\n",
    "        word_idxes = [i for i, t in enumerate(tokens)]\n",
    "        aug_idexes = self.sample(word_idxes, aug_cnt)\n",
    "        aug_idexes.sort(reverse=True)\n",
    "\n",
    "        for aug_idx in aug_idexes:\n",
    "            new_word = self.sample(self.model, 1)[0]\n",
    "            results.insert(aug_idx, new_word)\n",
    "\n",
    "        return results\n",
    "    \n",
    "    def get_model(self):\n",
    "        return ['Custom1', 'Custom2']\n",
    "        \n",
    "        \n",
    "\n",
    "aug = CustomWordAug()\n",
    "\n",
    "for token in tokens:\n",
    "    print('{} --> {}'.format(token, aug.augment([token])[0]))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
